Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 7 of 7 for TensorScatterUpdateOp (0.45 sec)

  1. tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h

          int64_t num_updates = indices_type.getDimSize(0);
          // For TF::TensorScatterUpdateOp, `indices` must have at least 2 axes:
          // `(num_updates, index_depth)`. Reshape indices and updates if necessary.
          if (std::is_same<TfOp, TF::TensorScatterUpdateOp>::value &&
              indices_type.getRank() == 1 && updates_type.getRank() == 1 &&
              index_depth == 1 && num_updates == 1) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 10.1K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc

        indices[i] = dimension_numbers.start_index_map()[i];
      }
    
      // Fill elements from start_indices with start_index_map
      Value scattered_start_indices = builder.create<TF::TensorScatterUpdateOp>(
          loc, empty_start_indices,
          /*indices=*/
          builder.create<TF::ReshapeOp>(
              loc, RankedTensorType::get({index_map_size, 1}, builder.getI64Type()),
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 13.2K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc

        TypeID::get<TF::StopGradientOp>(),
        TypeID::get<TF::StridedSliceOp>(),
        TypeID::get<TF::StridedSliceGradOp>(),
        TypeID::get<TF::SumOp>(),
        TypeID::get<TF::TanhGradOp>(),
        TypeID::get<TF::TensorScatterUpdateOp>(),
        TypeID::get<TF::TileOp>(),
        TypeID::get<TF::TopKV2Op>(),
        TypeID::get<TF::_UnaryOpsCompositionOp>(),
        TypeID::get<TF::UnsortedSegmentMaxOp>(),
        TypeID::get<TF::UnsortedSegmentMinOp>(),
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed Apr 24 04:08:35 UTC 2024
    - 21.7K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc

          : RewritePattern(
                InvertPermutationOp::getOperationName(), 1, context,
                {ConstOp::getOperationName(), RangeOp::getOperationName(),
                 ReshapeOp::getOperationName(),
                 TensorScatterUpdateOp::getOperationName()}) {}
    
      LogicalResult matchAndRewrite(Operation *src_op,
                                    PatternRewriter &rewriter) const override {
        auto op = cast<InvertPermutationOp>(src_op);
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 74.9K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc

      auto indices = builder.create<TF::ReshapeOp>(scatter.getLoc(),
                                                   scatter.getIndices(), shape);
      Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
          scatter.getLoc(), buffer, indices, scatter.getTensor());
      scatter.getOutputHandle().replaceAllUsesWith(tensor_scatter_update);
      scatter.erase();
      auto size = it->getSecond();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 39.2K bytes
    - Viewed (0)
  6. tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc

      }
      return success();
    }
    
    //===----------------------------------------------------------------------===//
    // TensorScatterUpdateOp
    //===----------------------------------------------------------------------===//
    
    LogicalResult TensorScatterUpdateOp::verify() {
      TensorScatterUpdateOp op = *this;
      if (!HasRankAtLeast(op.getTensor(), 1))
        return op.emitOpError(
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu May 09 22:07:10 UTC 2024
    - 170.8K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc

        // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
        if (updates_ty.getRank() == 0 &&
            (std::is_same<OpTy, TF::TensorScatterUpdateOp>::value ||
             std::is_same<OpTy, TF::TensorScatterAddOp>::value)) {
          if (!tensor_ty.hasStaticShape()) {
            return failure();
          }
    
          if (!indices_ty.hasStaticShape()) {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue Jun 11 20:00:43 UTC 2024
    - 291.8K bytes
    - Viewed (0)
Back to top